{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import scipy\n",
    "import xarray as xr\n",
    "from skdownscale.pointwise_models import BcsdPrecipitation, BcsdTemperature"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# utilities for plotting cdfs\n",
    "def plot_cdf(ax=None, **kwargs):\n",
    "    if ax:\n",
    "        plt.sca(ax)\n",
    "    else:\n",
    "        ax = plt.gca()\n",
    "\n",
    "    for label, X in kwargs.items():\n",
    "        vals = np.sort(X, axis=0)\n",
    "        pp = scipy.stats.mstats.plotting_positions(vals)\n",
    "        ax.plot(pp, vals, label=label)\n",
    "    ax.legend()\n",
    "    return ax\n",
    "\n",
    "\n",
    "def plot_cdf_by_month(ax=None, **kwargs):\n",
    "    fig, axes = plt.subplots(4, 3, sharex=True, sharey=False, figsize=(12, 8))\n",
    "\n",
    "    for label, X in kwargs.items():\n",
    "        for month, ax in zip(range(1, 13), axes.flat):\n",
    "\n",
    "            vals = np.sort(X[X.index.month == month], axis=0)\n",
    "            pp = scipy.stats.mstats.plotting_positions(vals)\n",
    "            ax.plot(pp, vals, label=label)\n",
    "            ax.set_title(month)\n",
    "    ax.legend()\n",
    "    return ax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# open a small dataset for training\n",
    "training = xr.open_zarr(\"../data/downscale_test_data.zarr.zip\", group=\"training\")\n",
    "training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# open a small dataset of observations (targets)\n",
    "targets = xr.open_zarr(\"../data/downscale_test_data.zarr.zip\", group=\"targets\")\n",
    "targets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# extract 1 point of training data for precipitation and temperature\n",
    "X_temp = training.isel(point=0).to_dataframe()[[\"T2max\"]].resample(\"MS\").mean() - 273.13\n",
    "X_pcp = training.isel(point=0).to_dataframe()[[\"PREC_TOT\"]].resample(\"MS\").sum() * 24\n",
    "display(X_temp.head(), X_pcp.head())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# extract 1 point of target data for precipitation and temperature\n",
    "y_temp = targets.isel(point=0).to_dataframe()[[\"Tmax\"]].resample(\"MS\").mean()\n",
    "y_pcp = targets.isel(point=0).to_dataframe()[[\"Prec\"]].resample(\"MS\").sum()\n",
    "display(y_temp.head(), y_pcp.head())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit/predict the BCSD Temperature model\n",
    "bcsd_temp = BcsdTemperature()\n",
    "bcsd_temp.fit(X_temp, y_temp)\n",
    "out = bcsd_temp.predict(X_temp) + X_temp\n",
    "plot_cdf(X=X_temp, y=y_temp, out=out)\n",
    "out.plot()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_cdf_by_month(X=X_temp, y=y_temp, out=out)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit/predict the BCSD Precipitation model\n",
    "bcsd_pcp = BcsdPrecipitation()\n",
    "bcsd_pcp.fit(X_pcp, y_pcp)\n",
    "out = bcsd_pcp.predict(X_pcp) * X_pcp\n",
    "plot_cdf(X=X_pcp, y=y_pcp, out=out)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_cdf_by_month(X=X_pcp, y=y_pcp, out=out)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
